1 기계학습 EDA 활용

기계학습을 탐색적 데이터 분석에 효과적으로 활용하기 위해서 예측변수(고객이탈)에 대해 유의적인 영향을 미치는 중요 변수를 식별해 내는 것이 중요하다. 이를 위해서 기계학습 알고리즘을 통해 중요변수를 추출해보자.

기계학습을 통해서 중요변수를 추출한 후에 단변량, 교차분석을 통해서 고객이탈 모형 개발에 사용될 데이터에 대해서 익숙해진다.

지도 기계학습 모형에 대한 자세한 사항은 xwMOOC 모형, 지도학습모형 → EDA - 포도주(wine)을 참조한다.

2 중요변수 추출

random forest 기계학습 알고리즘을 통해서 중요변수를 추출한다.

# 1. 데이터 ------
churn_m_df <- read_rds("data/churn_m_df.rds") %>% 
    sample_frac(0.1)

# 2. 중요변수 추출 기계학습 ------
## 2.1. 훈련/시험 데이터 분할 ------
idx <- createDataPartition(churn_m_df$churn, 
                           p = 0.8, 
                           list = FALSE, 
                           times = 1)

churn_train <- churn_m_df[ idx,]
churn_test  <- churn_m_df[-idx,]

## 2.2. 모형적합 ------
fit_ctrl <- trainControl(method = "repeatedcv", number = 5, repeats = 3)

churn_rf <- train(churn ~ ., 
                 data = churn_train, 
                 method = "rf", 
                 preProcess = c("scale", "center"),
                 trControl = fit_ctrl,
                 verbose = FALSE)

## 2.3. 모형성능평가 ------

test_predict <- predict(churn_rf, churn_test)
confusionMatrix(test_predict, churn_test$churn)
Confusion Matrix and Statistics

          Reference
Prediction No Yes
       No  98  18
       Yes  3  21
                                          
               Accuracy : 0.85            
                 95% CI : (0.7799, 0.9047)
    No Information Rate : 0.7214          
    P-Value [Acc > NIR] : 0.0002462       
                                          
                  Kappa : 0.5769          
 Mcnemar's Test P-Value : 0.0022502       
                                          
            Sensitivity : 0.9703          
            Specificity : 0.5385          
         Pos Pred Value : 0.8448          
         Neg Pred Value : 0.8750          
             Prevalence : 0.7214          
         Detection Rate : 0.7000          
   Detection Prevalence : 0.8286          
      Balanced Accuracy : 0.7544          
                                          
       'Positive' Class : No              
                                          
# 3. 모형 설명 -----
## 3.1. 중요변수 추출 -----
churn_rf_imp <- varImp(churn_rf, scale = TRUE)

churn_rf_imp$importance %>%
    as.data.frame() %>%
    rownames_to_column(var="variable") %>% 
    top_n(3, Overall) %>% 
    pull(variable)
[1] "tenure"          "monthly_charges" "total_charges"  
## 3.2. 중요변수 시각화 -----
churn_rf_imp$importance %>%
    as.data.frame() %>%
    rownames_to_column(var="variable") %>%
    ggplot(aes(x = reorder(variable, Overall), y = Overall)) +
    geom_bar(stat = "identity", fill = "#1F77B4", alpha = 0.8) +
    coord_flip() +
    labs(y="중요도", x="") +
    theme_minimal(base_family="NanumGothic")

기계학습 모형을 돌려서… 상위 4개 변수가 중요한 역할을 한 것으로 식별이 되었다고 볼 수 있다.

3 지도학습 모형 시각화

기계학습을 통해서 tenure, total_charges, monthly_charges, payment_method 변수가 다른 변수에 비해서 고객이탈 분류에 중요한 역할을 수행하는 것을 확인했다. 이를 염두에 두고 시각화를 진행해 보자.

3.1 정적 시각화

고객이탈 분류가 목적이기 때문에 예측변수와 예측에 동원되는 설명변수를 시각화한다. 우선 연속형 변수와 범주형 변수로 변수를 구분한다. 그리고 나서 범주형 변수의 경우 변주형 변수를 모아서 “고객이탈”과 결합하여 시각화한다.

## 범주형 변수
cat_var <- churn_m_df %>% 
    select_if(is.factor) %>% 
    colnames()

## 연속형 변수
cont_var <- churn_m_df %>% 
    select_if(is.numeric) %>% 
    colnames()


# 2. 탐색적 데이터 분석 ------
## 2.1. 정적 시각화 -----
### 범주형 변수 시각화
visualize_category <- function(variable) {
    # var <- enquo(variable)
    var <- sym(variable)
    churn_m_df %>% 
        select_if(is.factor) %>% 
        count(!!var, churn) %>% 
        ggplot(aes(x=!!var, y=n, fill=churn)) +
        geom_col() +
        labs(x=rlang::quo_text(var), y="") +
        scale_y_continuous(labels = scales::comma) +
        scale_fill_viridis_d() +
        theme(legend.position = "none") +
        scale_x_discrete(label = function(x) abbreviate(x, minlength=7))
}

# visualize_category(gender)
# visualize_category("churn")

cat_barplot_lst <- map(setdiff(cat_var, "churn"), visualize_category)

multiplot(plotlist=cat_barplot_lst, cols=4)

범주형 변수와 동일한 방식으로 연속형 변수를 고객이탈과 결합하여 시각화한다.

### 연속형 변수 시각화
y_p <- churn_m_df %>%
    ggplot(aes(x = churn, fill = churn)) +
    geom_bar(alpha = 0.8) +
    scale_fill_manual(values = c("red", "gray")) +
    guides(fill = FALSE)

x_p <- churn_m_df %>%
    select(churn, cont_var) %>% 
    gather(variable, value, -churn) %>%
    ggplot(aes(x = value, y = churn, color = churn, fill = churn)) +
    facet_wrap( ~ variable, scale = "free", ncol = 3) +
    scale_color_manual(values = c("red", "gray")) +
    scale_fill_manual(values = c("red", "gray")) +
    geom_density_ridges(alpha = 0.8) +
    guides(fill = FALSE, color = FALSE)

plot_grid(y_p, x_p, rel_widths = c(1,3))

3.2 인터랙티브 시각화

crosstalk를 통해 예측변수와 설명변수를 연결하고 이를 plotly 팩키지로 인터랙티브 시각화한다. 자세한 방식은 포도주 분류 인터랙티브 시각화를 참조한다.

4 교차분석

기계학습을 통해서 추출된 고객이탈에 중요도가 큰 변수 네개(tenure, total_charges, monthly_charges, payment_method)를 대상을 ggpairs() 함수로 교차분석을 수행한다.

## 기계학습 중요 변수
imp_var <- churn_m_df %>% 
    select(tenure, total_charges, monthly_charges, payment_method) %>% 
    colnames()

# 3. 교차 분석 ------
churn_m_df %>% 
    select(imp_var, churn) %>% 
    GGally::ggpairs(aes(color=churn, alpha=0.7))

crosstalkplotly를 결합하여 인터랙티브 교차분석 그래프도 가능하다.

## 인터랙티브 교차그래프
churn_sd <- churn_m_df %>% 
    select(imp_var, churn) %>% 
    sample_frac(0.1) %>% SharedData$new(.)

churn_xy_g <- GGally::ggpairs(churn_sd, aes(color=churn, alpha=0.7))
highlight(ggplotly(churn_xy_g), on = "plotly_selected")